# -*- coding:utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
import torch
from utils import *
import torch.optim as optim
from Dataset import *
from predictor import *
from tqdm import tqdm
import time
import matplotlib.colors as mcolors
import os
import torch.optim.lr_scheduler as lr_scheduler
import logging
import json
import datetime


def configure_logging(log_file):
    """Configure the root logger to write DEBUG-and-above records to a file.

    Args:
        log_file: Path of the log file to write. When falsy (``None`` or an
            empty string) the previous hard-coded default
            ``'/load_balancing.log'`` is used.

    The file is opened with mode ``'w'``, so an existing file at that path is
    truncated. ``logging.basicConfig`` does nothing if the root logger
    already has handlers, so call this before any other logging setup.
    """
    logging.basicConfig(
        level=logging.DEBUG,
        format='%(asctime)s - %(levelname)s - %(message)s',
        # Honor the caller-supplied path instead of silently ignoring it.
        filename=log_file or '/load_balancing.log',
        filemode='w',
    )
    logging.info("Logging started.")



def test(test_dataset, test_dataloader, criterion, device,
         model_path=r'pth/tune_model/chicago/20250421_171046/1_5%.pth',
         mask_path=r'/mask.npy'):
    """Evaluate a saved model on a test loader and report masked MSE / MAE.

    For every batch the model output and the label are de-normalized with
    ``test_dataset.data_denormalization``, non-positive predictions are set
    to zero, the prediction is multiplied by the boolean mask, and the
    per-sample MSE and MAE against the label are accumulated.

    Args:
        test_dataset: Object exposing
            ``data_denormalization(tensor, feature_idx)``.
        test_dataloader: Iterable yielding ``(data, label, feature_idx)``
            batches.
        criterion: Unused; kept so existing call sites keep working.
        device: Torch device the model, mask and batches are moved to.
        model_path: Path of the pickled model passed to ``torch.load``.
        mask_path: Path of the ``.npy`` mask file.

    Returns:
        Tuple ``(avg_mse, avg_mae)``: the per-sample errors averaged over
        all samples seen.

    Raises:
        ZeroDivisionError: If the loader yields no samples.
    """
    # NOTE(review): torch.load unpickles an entire model object; only load
    # checkpoints from a trusted source. Newer PyTorch releases may require
    # weights_only=False for full-model pickles -- confirm against the
    # installed version.
    model = torch.load(model_path, map_location=device)
    model.eval()

    mask = np.load(mask_path)
    mask_tensor = torch.tensor(mask, dtype=torch.bool).to(device)

    total_mse = 0.0
    total_mae = 0.0
    total_samples = 0
    with torch.no_grad():
        for data, label, feature_idx in test_dataloader:
            data = data.to(device)
            # The model returns four values; only the prediction is needed.
            output, _, _, _ = model(data, prompt_flag=1)

            # Only the first entry of the batch's feature indices is used
            # for de-normalization.
            feature_idx = feature_idx[0]
            output = test_dataset.data_denormalization(output, feature_idx)
            label = test_dataset.data_denormalization(label, feature_idx)
            output = torch.where(output > 0, output, 0)
            label = label.to(device)

            # Add a leading and a trailing axis to the mask, then broadcast
            # it to the label's shape.
            mask_t = mask_tensor.unsqueeze(0).unsqueeze(-1).expand_as(label)

            # reshape (not view): the expanded mask has stride-0 dimensions
            # that view cannot always flatten.
            batch = label.size(0)
            mask_flat = mask_t.reshape(batch, -1)
            output_flat = output.reshape(output.size(0), -1)
            label_flat = label.reshape(batch, -1)

            # NOTE(review): only the prediction is masked; the label is used
            # as-is -- presumably labels are already zero outside the mask,
            # confirm.
            diff = output_flat * mask_flat - label_flat
            mse = torch.mean(diff ** 2, dim=1)
            mae = torch.mean(torch.abs(diff), dim=1)

            total_mse += torch.sum(mse).item()
            total_mae += torch.sum(mae).item()
            total_samples += len(label)

    avg_mse = total_mse / total_samples
    avg_mae = total_mae / total_samples

    print(f'Overall Test MSE: {avg_mse} MAE:{avg_mae}')
    return avg_mse, avg_mae

def main():
    """Load the data, split it 60/20/20 in order, and evaluate the test split.

    Reads ``/data.npy``, clips negative values to zero, converts the array to
    a float tensor on the selected device with a trailing singleton axis,
    and splits it along the first axis into consecutive train / validation /
    test parts. A dataset and loader are built for each part; only the test
    dataset and loader are passed on to ``test``.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    raw = np.clip(np.load(r'/data.npy'), a_min=0, a_max=None)
    input_data = torch.from_numpy(raw).to(device).float().unsqueeze(-1)

    batch_size = 32

    # Consecutive (unshuffled) 60 / 20 / remainder split along axis 0.
    total = len(input_data)
    train_size = int(0.6 * total)
    valid_size = int(0.2 * total)
    test_size = total - train_size - valid_size
    x_train_data, x_valid_data, x_test_data = input_data.split(
        [train_size, valid_size, test_size]
    )

    train_dataset = preDataset(x_train_data)
    valid_dataset = preDataset(x_valid_data)
    test_dataset = preDataset(x_test_data)

    # Only the training loader shuffles; each loader uses its own dataset's
    # collate function.
    loader_specs = (
        (train_dataset, True),
        (valid_dataset, False),
        (test_dataset, False),
    )
    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=do_shuffle,
            collate_fn=dataset.collate_fn,
        )
        for dataset, do_shuffle in loader_specs
    ]

    criterion = DownStream_DistanceLoss()
    test(test_dataset, test_dataloader, criterion, device)

# Run the evaluation only when this file is executed as a script, not when
# it is imported.
if __name__ == '__main__':
    main()